Use _cast_input_dtype in poly Linear.forward#3177
Conversation
The raw x.to(A.dtype) cast bypasses the disable_lora_input_dtype_casting context manager. Switch to the BaseTunerLayer._cast_input_dtype helper so poly respects the same casting controls as other tuners (fourierft, vera, etc.).
|
@Chessing234 Thanks for the PR. Could you please check all PEFT methods for the same pattern and bundle all changes into a single PR instead of submitting one for each PEFT method individually? Thanks. |
|
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. |
|
Thanks for the feedback! I've combined this fix along with similar fixes for other PEFT methods into a single PR as requested, so I'll be closing this one out. |
|
Thanks for combining these fixes into a single PR. Please ping me when it's ready. I'll close the individual PRs then. |
|
I've checked all PEFT methods for the same pattern and bundled all the changes into a single PR, as requested. You can find the bundled update in #3280. |
Bug
poly.Linear.forwardcasts the input with a rawx.to(A.dtype)instead ofself._cast_input_dtype(...), bypassing thedisable_lora_input_dtype_castingcontext manager.
Root cause
BaseTunerLayerexposes_cast_input_dtype, which respects the module-levelflag for disabling input dtype casting. Most tuners use that helper, but
poly/layer.pystill uses the raw.to()call.Why the fix is correct
PolyLayerinherits fromBaseTunerLayer, soself._cast_input_dtypeisavailable on the
Linearsubclass.this brings poly in line with those tuners.
performs the same cast).
Change
src/peft/tuners/poly/layer.py:x = x.to(A.dtype)→x = self._cast_input_dtype(x, A.dtype).